{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 9.树回归算法"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from numpy import *"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 9.1.读取数据"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "def loadDataSet(fileName):\n",
    "    \"\"\"\n",
    "    读取数据\n",
    "    参数：\n",
    "        fileName -- 文件名\n",
    "    参数：\n",
    "        dataMat -- 数据矩阵\n",
    "    \"\"\"\n",
    "    # 新建数据矩阵\n",
    "    dataMat = []\n",
    "    fr = open(fileName)\n",
    "    # 以制表符分割\n",
    "    for line in fr.readlines():\n",
    "        curLine = line.strip().split('\\t')\n",
    "        fltLine = list(map(float,curLine))\n",
    "        dataMat.append(fltLine)\n",
    "    return dataMat"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 73,
   "metadata": {},
   "outputs": [],
   "source": [
    "def binSplitDataSet(dataSet, feature, value):\n",
    "    \"\"\"\n",
    "    二分割数据\n",
    "    参数：\n",
    "        dataSet -- 数据集\n",
    "        feature -- 分割特征\n",
    "        value -- 分割点\n",
    "    返回：\n",
    "        mat0 -- 矩阵1\n",
    "        mat1 -- 矩阵2\n",
    "    \"\"\"\n",
    "    #if len(nonzero(dataSet[:,feature] > value)[0]) == 0:\n",
    "    #    return array([]), array([])\n",
    "    mat0 = dataSet[nonzero(dataSet[:,feature] > value)[0],:]\n",
    "    mat1 = dataSet[nonzero(dataSet[:,feature] <= value)[0],:]\n",
    "    return mat0,mat1"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 9.2.CART算法\n",
    "### 9.2.1.支持函数"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def regLeaf(dataSet):\n",
    "    return mean(dataSet[:,-1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "def regErr(dataSet):\n",
    "    return var(dataSet[:,-1]) * shape(dataSet)[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 74,
   "metadata": {},
   "outputs": [],
   "source": [
    "def chooseBestSplit(dataSet, leafType=regLeaf, errType=regErr, ops=(1,4)):\n",
    "    \"\"\"\n",
    "    选择最佳分割点\n",
    "    参数：\n",
    "        dataSet -- 数据集\n",
    "        leafTypr -- 叶子类型\n",
    "        errType -- 误差类型\n",
    "        ops -- 参数元组\n",
    "    返回：\n",
    "        bestIndex -- 最佳类型索引\n",
    "        bestValue -- 最佳值\n",
    "    \"\"\"\n",
    "    # 读取元组中的参数\n",
    "    tolS = ops[0]; tolN = ops[1]\n",
    "    # 如果该类型下只有一个值，无法继续划分，则返回该值\n",
    "    if len(set(dataSet[:,-1].T.tolist()[0])) == 1:\n",
    "        #print(\"end1\")\n",
    "        return None, leafType(dataSet)\n",
    "    # 数据维度\n",
    "    m,n = shape(dataSet)\n",
    "    # 最佳分割类型由误差决定\n",
    "    S = errType(dataSet)\n",
    "    bestS = inf; bestIndex = 0; bestValue = 0\n",
    "    # 遍历所有特征\n",
    "    for featIndex in range(n-1):\n",
    "        # 遍历所有值\n",
    "        for splitVal in set(dataSet[:, featIndex].A.squeeze().tolist()):\n",
    "            # 二分割\n",
    "            mat0, mat1 = binSplitDataSet(dataSet, featIndex, splitVal)\n",
    "            # 如果分割后矩阵的维度不满足要求，则continue\n",
    "            if (shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN): \n",
    "                #print(\"continue\")\n",
    "                continue\n",
    "            # 计算总误差\n",
    "            #print(\"do\")\n",
    "            newS = errType(mat0) + errType(mat1)\n",
    "            # 如果误差最佳\n",
    "            if newS < bestS: \n",
    "                # 记录相关参数\n",
    "                bestIndex = featIndex\n",
    "                bestValue = splitVal\n",
    "                bestS = newS\n",
    "    # 如果分割的效果提升没有超过阈值，就不进行分割\n",
    "    if (S - bestS) < tolS: \n",
    "        #print(\"end2\")\n",
    "        #print(f\"S = {S}, bestS = {bestS}, tolS = {tolS}\")\n",
    "        return None, leafType(dataSet)\n",
    "    mat0, mat1 = binSplitDataSet(dataSet, bestIndex, bestValue)\n",
    "    if (shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN):\n",
    "        #print(\"end3\")\n",
    "        return None, leafType(dataSet)\n",
    "    #print(\"end4\")\n",
    "    return bestIndex,bestValue"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 9.2.2.生成树的算法"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "def createTree(dataSet, leafType=regLeaf, errType=regErr, ops=(1,4)):\n",
    "    \"\"\"\n",
    "    建立树\n",
    "    参数：\n",
    "        dataSet -- 数据集\n",
    "        leafType -- 叶子类型\n",
    "        errType -- 误差类型\n",
    "        ops -- 参数元组\n",
    "    返回：\n",
    "        retTree -- 返回的树\n",
    "    \"\"\"\n",
    "    # 选择最佳分割点\n",
    "    feat, val = chooseBestSplit(dataSet, leafType, errType, ops)\n",
    "    # 如果达到迭代停止条件，返回val\n",
    "    if feat == None: return val\n",
    "    # 新建子树字典\n",
    "    retTree = {}\n",
    "    # 分割特征的索引\n",
    "    retTree['spInd'] = feat\n",
    "    # 分割点的值\n",
    "    retTree['spVal'] = val\n",
    "    # 二分割\n",
    "    lSet, rSet = binSplitDataSet(dataSet, feat, val)\n",
    "    # 左右子树\n",
    "    retTree['left'] = createTree(lSet, leafType, errType, ops)\n",
    "    retTree['right'] = createTree(rSet, leafType, errType, ops)\n",
    "    return retTree "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "测试"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 75,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'spInd': 0,\n",
       " 'spVal': 0.48813,\n",
       " 'left': 1.0180967672413792,\n",
       " 'right': -0.04465028571428572}"
      ]
     },
     "execution_count": 75,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "myDat = loadDataSet(\"ex00.txt\")\n",
    "myMat = mat(myDat)\n",
    "createTree(myMat)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "在ex0.txt上测试"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 76,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'spInd': 1,\n",
       " 'spVal': 0.39435,\n",
       " 'left': {'spInd': 1,\n",
       "  'spVal': 0.582002,\n",
       "  'left': {'spInd': 1,\n",
       "   'spVal': 0.797583,\n",
       "   'left': 3.9871632,\n",
       "   'right': 2.9836209534883724},\n",
       "  'right': 1.980035071428571},\n",
       " 'right': {'spInd': 1,\n",
       "  'spVal': 0.197834,\n",
       "  'left': 1.0289583666666666,\n",
       "  'right': -0.023838155555555553}}"
      ]
     },
     "execution_count": 76,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "myDat1 = loadDataSet(\"ex0.txt\")\n",
    "myMat1 = mat(myDat1)\n",
    "createTree(myMat1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 9.3.树剪枝\n",
    "### 9.3.1.支持函数"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 77,
   "metadata": {},
   "outputs": [],
   "source": [
    "def isTree(obj):\n",
    "    return (type(obj).__name__=='dict')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 78,
   "metadata": {},
   "outputs": [],
   "source": [
    "def getMean(tree):\n",
    "    if isTree(tree['right']): tree['right'] = getMean(tree['right'])\n",
    "    if isTree(tree['left']): tree['left'] = getMean(tree['left'])\n",
    "    return (tree['left']+tree['right'])/2.0"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 9.3.2.剪枝函数"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 81,
   "metadata": {},
   "outputs": [],
   "source": [
    "def prune(tree, testData):\n",
    "    \"\"\"\n",
    "    剪枝函数\n",
    "    参数：\n",
    "        tree -- 待剪枝的树\n",
    "        testData -- 测试数据\n",
    "    返回：\n",
    "        treeMean -- 合并的结果\n",
    "        或\n",
    "        tree -- 不需要剪枝\n",
    "    \"\"\"\n",
    "    # 如果没有测试数据，就直接把整棵树合并\n",
    "    if shape(testData)[0] == 0: return getMean(tree)\n",
    "    # 如果没有树可以合并，则分割节点\n",
    "    if (isTree(tree['right']) or isTree(tree['left'])):\n",
    "        lSet, rSet = binSplitDataSet(testData, tree['spInd'], tree['spVal'])\n",
    "    # 递归分割左右子树\n",
    "    if isTree(tree['left']): tree['left'] = prune(tree['left'], lSet)\n",
    "    if isTree(tree['right']): tree['right'] =  prune(tree['right'], rSet)\n",
    "    # 如果两棵树都是叶子节点，则判断是否要合并\n",
    "    if not isTree(tree['left']) and not isTree(tree['right']):\n",
    "        lSet, rSet = binSplitDataSet(testData, tree['spInd'], tree['spVal'])\n",
    "        # 不合并的误差\n",
    "        errorNoMerge = sum(power(lSet[:,-1] - tree['left'],2)) +\\\n",
    "            sum(power(rSet[:,-1] - tree['right'],2))\n",
    "        treeMean = (tree['left']+tree['right'])/2.0\n",
    "        # 合并误差\n",
    "        errorMerge = sum(power(testData[:,-1] - treeMean,2))\n",
    "        # 如果合并后，误差减小，则执行合并\n",
    "        if errorMerge < errorNoMerge: \n",
    "            print(\"merging\")\n",
    "            return treeMean\n",
    "        # 反之，不执行合并\n",
    "        else: return tree\n",
    "    else: return tree"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "测试"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 85,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "merging\n",
      "merging\n",
      "merging\n",
      "merging\n",
      "merging\n",
      "merging\n",
      "merging\n",
      "merging\n",
      "merging\n",
      "merging\n",
      "merging\n",
      "merging\n",
      "merging\n",
      "merging\n",
      "merging\n",
      "merging\n",
      "merging\n",
      "merging\n",
      "merging\n",
      "merging\n",
      "merging\n",
      "merging\n",
      "merging\n",
      "merging\n",
      "merging\n",
      "merging\n",
      "merging\n",
      "merging\n",
      "merging\n",
      "merging\n",
      "merging\n",
      "merging\n",
      "merging\n",
      "merging\n",
      "merging\n",
      "merging\n",
      "merging\n",
      "merging\n",
      "merging\n",
      "merging\n",
      "merging\n",
      "merging\n",
      "merging\n",
      "merging\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "{'spInd': 0,\n",
       " 'spVal': 0.499171,\n",
       " 'left': {'spInd': 0,\n",
       "  'spVal': 0.729397,\n",
       "  'left': {'spInd': 0,\n",
       "   'spVal': 0.952833,\n",
       "   'left': {'spInd': 0,\n",
       "    'spVal': 0.965969,\n",
       "    'left': 92.5239915,\n",
       "    'right': {'spInd': 0,\n",
       "     'spVal': 0.956951,\n",
       "     'left': {'spInd': 0,\n",
       "      'spVal': 0.958512,\n",
       "      'left': {'spInd': 0,\n",
       "       'spVal': 0.960398,\n",
       "       'left': 112.386764,\n",
       "       'right': 123.559747},\n",
       "      'right': 135.837013},\n",
       "     'right': 111.2013225}},\n",
       "   'right': {'spInd': 0,\n",
       "    'spVal': 0.759504,\n",
       "    'left': {'spInd': 0,\n",
       "     'spVal': 0.763328,\n",
       "     'left': {'spInd': 0,\n",
       "      'spVal': 0.769043,\n",
       "      'left': {'spInd': 0,\n",
       "       'spVal': 0.790312,\n",
       "       'left': {'spInd': 0,\n",
       "        'spVal': 0.806158,\n",
       "        'left': {'spInd': 0,\n",
       "         'spVal': 0.815215,\n",
       "         'left': {'spInd': 0,\n",
       "          'spVal': 0.833026,\n",
       "          'left': {'spInd': 0,\n",
       "           'spVal': 0.841547,\n",
       "           'left': {'spInd': 0,\n",
       "            'spVal': 0.841625,\n",
       "            'left': {'spInd': 0,\n",
       "             'spVal': 0.944221,\n",
       "             'left': {'spInd': 0,\n",
       "              'spVal': 0.948822,\n",
       "              'left': 96.41885225,\n",
       "              'right': 69.318649},\n",
       "             'right': {'spInd': 0,\n",
       "              'spVal': 0.85497,\n",
       "              'left': {'spInd': 0,\n",
       "               'spVal': 0.936524,\n",
       "               'left': 110.03503850000001,\n",
       "               'right': {'spInd': 0,\n",
       "                'spVal': 0.934853,\n",
       "                'left': 65.548418,\n",
       "                'right': {'spInd': 0,\n",
       "                 'spVal': 0.925782,\n",
       "                 'left': 115.753994,\n",
       "                 'right': {'spInd': 0,\n",
       "                  'spVal': 0.910975,\n",
       "                  'left': {'spInd': 0,\n",
       "                   'spVal': 0.912161,\n",
       "                   'left': 94.3961145,\n",
       "                   'right': 85.005351},\n",
       "                  'right': {'spInd': 0,\n",
       "                   'spVal': 0.901444,\n",
       "                   'left': {'spInd': 0,\n",
       "                    'spVal': 0.908629,\n",
       "                    'left': 106.814667,\n",
       "                    'right': 118.513475},\n",
       "                   'right': {'spInd': 0,\n",
       "                    'spVal': 0.901421,\n",
       "                    'left': 87.300625,\n",
       "                    'right': {'spInd': 0,\n",
       "                     'spVal': 0.892999,\n",
       "                     'left': {'spInd': 0,\n",
       "                      'spVal': 0.900699,\n",
       "                      'left': 100.133819,\n",
       "                      'right': 108.094934},\n",
       "                     'right': {'spInd': 0,\n",
       "                      'spVal': 0.888426,\n",
       "                      'left': 82.436686,\n",
       "                      'right': {'spInd': 0,\n",
       "                       'spVal': 0.872199,\n",
       "                       'left': 98.54454949999999,\n",
       "                       'right': 106.16859550000001}}}}}}}}},\n",
       "              'right': {'spInd': 0,\n",
       "               'spVal': 0.84294,\n",
       "               'left': {'spInd': 0,\n",
       "                'spVal': 0.847219,\n",
       "                'left': 89.20993,\n",
       "                'right': 76.240984},\n",
       "               'right': 95.893131}}},\n",
       "            'right': 60.552308},\n",
       "           'right': 124.87935300000001},\n",
       "          'right': {'spInd': 0,\n",
       "           'spVal': 0.823848,\n",
       "           'left': 76.723835,\n",
       "           'right': {'spInd': 0,\n",
       "            'spVal': 0.819722,\n",
       "            'left': 59.342323,\n",
       "            'right': 70.054508}}},\n",
       "         'right': {'spInd': 0,\n",
       "          'spVal': 0.811602,\n",
       "          'left': 118.319942,\n",
       "          'right': {'spInd': 0,\n",
       "           'spVal': 0.811363,\n",
       "           'left': 99.841379,\n",
       "           'right': 112.981216}}},\n",
       "        'right': 73.49439925},\n",
       "       'right': {'spInd': 0,\n",
       "        'spVal': 0.786865,\n",
       "        'left': 114.4008695,\n",
       "        'right': 102.26514075}},\n",
       "      'right': 64.041941},\n",
       "     'right': 115.199195},\n",
       "    'right': 78.08564325}},\n",
       "  'right': {'spInd': 0,\n",
       "   'spVal': 0.640515,\n",
       "   'left': {'spInd': 0,\n",
       "    'spVal': 0.642373,\n",
       "    'left': {'spInd': 0,\n",
       "     'spVal': 0.642707,\n",
       "     'left': {'spInd': 0,\n",
       "      'spVal': 0.665329,\n",
       "      'left': {'spInd': 0,\n",
       "       'spVal': 0.706961,\n",
       "       'left': {'spInd': 0,\n",
       "        'spVal': 0.70889,\n",
       "        'left': {'spInd': 0,\n",
       "         'spVal': 0.716211,\n",
       "         'left': 110.90283,\n",
       "         'right': {'spInd': 0,\n",
       "          'spVal': 0.710234,\n",
       "          'left': 103.345308,\n",
       "          'right': 108.553919}},\n",
       "        'right': 135.416767},\n",
       "       'right': {'spInd': 0,\n",
       "        'spVal': 0.698472,\n",
       "        'left': {'spInd': 0,\n",
       "         'spVal': 0.69892,\n",
       "         'left': {'spInd': 0,\n",
       "          'spVal': 0.699873,\n",
       "          'left': {'spInd': 0,\n",
       "           'spVal': 0.70639,\n",
       "           'left': 106.180427,\n",
       "           'right': 105.062147},\n",
       "          'right': 115.586605},\n",
       "         'right': 92.470636},\n",
       "        'right': {'spInd': 0,\n",
       "         'spVal': 0.689099,\n",
       "         'left': 120.521925,\n",
       "         'right': {'spInd': 0,\n",
       "          'spVal': 0.666452,\n",
       "          'left': 101.91115275,\n",
       "          'right': 112.78136649999999}}}},\n",
       "      'right': {'spInd': 0,\n",
       "       'spVal': 0.661073,\n",
       "       'left': 121.980607,\n",
       "       'right': {'spInd': 0,\n",
       "        'spVal': 0.652462,\n",
       "        'left': 115.687524,\n",
       "        'right': 112.715799}}},\n",
       "     'right': 82.500766},\n",
       "    'right': 140.613941},\n",
       "   'right': {'spInd': 0,\n",
       "    'spVal': 0.613004,\n",
       "    'left': {'spInd': 0,\n",
       "     'spVal': 0.623909,\n",
       "     'left': {'spInd': 0,\n",
       "      'spVal': 0.628061,\n",
       "      'left': {'spInd': 0,\n",
       "       'spVal': 0.637999,\n",
       "       'left': 82.713621,\n",
       "       'right': {'spInd': 0,\n",
       "        'spVal': 0.632691,\n",
       "        'left': 91.656617,\n",
       "        'right': 93.645293}},\n",
       "      'right': {'spInd': 0,\n",
       "       'spVal': 0.624827,\n",
       "       'left': 117.628346,\n",
       "       'right': 105.970743}},\n",
       "     'right': 82.04976400000001},\n",
       "    'right': {'spInd': 0,\n",
       "     'spVal': 0.606417,\n",
       "     'left': 168.180746,\n",
       "     'right': {'spInd': 0,\n",
       "      'spVal': 0.513332,\n",
       "      'left': {'spInd': 0,\n",
       "       'spVal': 0.533511,\n",
       "       'left': {'spInd': 0,\n",
       "        'spVal': 0.548539,\n",
       "        'left': {'spInd': 0,\n",
       "         'spVal': 0.553797,\n",
       "         'left': {'spInd': 0,\n",
       "          'spVal': 0.560301,\n",
       "          'left': {'spInd': 0,\n",
       "           'spVal': 0.599142,\n",
       "           'left': 93.521396,\n",
       "           'right': {'spInd': 0,\n",
       "            'spVal': 0.589806,\n",
       "            'left': 130.378529,\n",
       "            'right': {'spInd': 0,\n",
       "             'spVal': 0.582311,\n",
       "             'left': 111.9849935,\n",
       "             'right': {'spInd': 0,\n",
       "              'spVal': 0.571214,\n",
       "              'left': 82.589328,\n",
       "              'right': {'spInd': 0,\n",
       "               'spVal': 0.569327,\n",
       "               'left': 114.872056,\n",
       "               'right': 108.435392}}}}},\n",
       "          'right': 82.903945},\n",
       "         'right': 129.0624485},\n",
       "        'right': {'spInd': 0,\n",
       "         'spVal': 0.546601,\n",
       "         'left': 83.114502,\n",
       "         'right': {'spInd': 0,\n",
       "          'spVal': 0.537834,\n",
       "          'left': 97.3405265,\n",
       "          'right': 90.995536}}},\n",
       "       'right': {'spInd': 0,\n",
       "        'spVal': 0.51915,\n",
       "        'left': {'spInd': 0,\n",
       "         'spVal': 0.531944,\n",
       "         'left': 129.766743,\n",
       "         'right': 124.795495},\n",
       "        'right': 116.176162}},\n",
       "      'right': {'spInd': 0,\n",
       "       'spVal': 0.508548,\n",
       "       'left': 101.075609,\n",
       "       'right': {'spInd': 0,\n",
       "        'spVal': 0.508542,\n",
       "        'left': 93.292829,\n",
       "        'right': 96.403373}}}}}}},\n",
       " 'right': {'spInd': 0,\n",
       "  'spVal': 0.457563,\n",
       "  'left': {'spInd': 0,\n",
       "   'spVal': 0.465561,\n",
       "   'left': {'spInd': 0,\n",
       "    'spVal': 0.467383,\n",
       "    'left': {'spInd': 0,\n",
       "     'spVal': 0.483803,\n",
       "     'left': {'spInd': 0,\n",
       "      'spVal': 0.487381,\n",
       "      'left': 8.53677,\n",
       "      'right': 27.729263},\n",
       "     'right': 5.224234},\n",
       "    'right': {'spInd': 0,\n",
       "     'spVal': 0.46568,\n",
       "     'left': -9.712925,\n",
       "     'right': -23.777531}},\n",
       "   'right': {'spInd': 0,\n",
       "    'spVal': 0.463241,\n",
       "    'left': 30.051931,\n",
       "    'right': 17.171057}},\n",
       "  'right': {'spInd': 0,\n",
       "   'spVal': 0.455761,\n",
       "   'left': -34.044555,\n",
       "   'right': {'spInd': 0,\n",
       "    'spVal': 0.126833,\n",
       "    'left': {'spInd': 0,\n",
       "     'spVal': 0.130626,\n",
       "     'left': {'spInd': 0,\n",
       "      'spVal': 0.382037,\n",
       "      'left': {'spInd': 0,\n",
       "       'spVal': 0.388789,\n",
       "       'left': {'spInd': 0,\n",
       "        'spVal': 0.437652,\n",
       "        'left': -4.1911745,\n",
       "        'right': {'spInd': 0,\n",
       "         'spVal': 0.412516,\n",
       "         'left': {'spInd': 0,\n",
       "          'spVal': 0.418943,\n",
       "          'left': {'spInd': 0,\n",
       "           'spVal': 0.426711,\n",
       "           'left': {'spInd': 0,\n",
       "            'spVal': 0.428582,\n",
       "            'left': 19.745224,\n",
       "            'right': 15.224266},\n",
       "           'right': -21.594268},\n",
       "          'right': 44.161493},\n",
       "         'right': {'spInd': 0,\n",
       "          'spVal': 0.403228,\n",
       "          'left': -26.419289,\n",
       "          'right': 0.6359300000000001}}},\n",
       "       'right': 23.197474},\n",
       "      'right': {'spInd': 0,\n",
       "       'spVal': 0.335182,\n",
       "       'left': {'spInd': 0,\n",
       "        'spVal': 0.370042,\n",
       "        'left': {'spInd': 0,\n",
       "         'spVal': 0.378965,\n",
       "         'left': -29.007783,\n",
       "         'right': {'spInd': 0,\n",
       "          'spVal': 0.373501,\n",
       "          'left': {'spInd': 0,\n",
       "           'spVal': 0.377383,\n",
       "           'left': 13.583555,\n",
       "           'right': 5.241196},\n",
       "          'right': -8.228297}},\n",
       "        'right': {'spInd': 0,\n",
       "         'spVal': 0.35679,\n",
       "         'left': -32.124495,\n",
       "         'right': {'spInd': 0,\n",
       "          'spVal': 0.350725,\n",
       "          'left': -9.9938275,\n",
       "          'right': -26.851234812500003}}},\n",
       "       'right': {'spInd': 0,\n",
       "        'spVal': 0.324274,\n",
       "        'left': 22.286959625,\n",
       "        'right': {'spInd': 0,\n",
       "         'spVal': 0.309133,\n",
       "         'left': {'spInd': 0,\n",
       "          'spVal': 0.310956,\n",
       "          'left': -20.3973335,\n",
       "          'right': -49.939516},\n",
       "         'right': {'spInd': 0,\n",
       "          'spVal': 0.131833,\n",
       "          'left': {'spInd': 0,\n",
       "           'spVal': 0.138619,\n",
       "           'left': {'spInd': 0,\n",
       "            'spVal': 0.156067,\n",
       "            'left': {'spInd': 0,\n",
       "             'spVal': 0.166765,\n",
       "             'left': {'spInd': 0,\n",
       "              'spVal': 0.193282,\n",
       "              'left': {'spInd': 0,\n",
       "               'spVal': 0.211633,\n",
       "               'left': {'spInd': 0,\n",
       "                'spVal': 0.228473,\n",
       "                'left': {'spInd': 0,\n",
       "                 'spVal': 0.25807,\n",
       "                 'left': {'spInd': 0,\n",
       "                  'spVal': 0.284794,\n",
       "                  'left': {'spInd': 0,\n",
       "                   'spVal': 0.300318,\n",
       "                   'left': 8.814725,\n",
       "                   'right': {'spInd': 0,\n",
       "                    'spVal': 0.297107,\n",
       "                    'left': -18.051318,\n",
       "                    'right': {'spInd': 0,\n",
       "                     'spVal': 0.295993,\n",
       "                     'left': -1.798377,\n",
       "                     'right': {'spInd': 0,\n",
       "                      'spVal': 0.290749,\n",
       "                      'left': -14.988279,\n",
       "                      'right': -14.391613}}}},\n",
       "                  'right': {'spInd': 0,\n",
       "                   'spVal': 0.273863,\n",
       "                   'left': 35.623746,\n",
       "                   'right': {'spInd': 0,\n",
       "                    'spVal': 0.264926,\n",
       "                    'left': -9.457556,\n",
       "                    'right': {'spInd': 0,\n",
       "                     'spVal': 0.264639,\n",
       "                     'left': 5.280579,\n",
       "                     'right': 2.557923}}}},\n",
       "                 'right': {'spInd': 0,\n",
       "                  'spVal': 0.228628,\n",
       "                  'left': {'spInd': 0,\n",
       "                   'spVal': 0.228751,\n",
       "                   'left': -9.601409499999999,\n",
       "                   'right': -30.812912},\n",
       "                  'right': -2.266273}},\n",
       "                'right': 6.099239},\n",
       "               'right': {'spInd': 0,\n",
       "                'spVal': 0.202161,\n",
       "                'left': -16.42737025,\n",
       "                'right': -2.6781805}},\n",
       "              'right': 9.5773855},\n",
       "             'right': {'spInd': 0,\n",
       "              'spVal': 0.156273,\n",
       "              'left': {'spInd': 0,\n",
       "               'spVal': 0.164134,\n",
       "               'left': {'spInd': 0,\n",
       "                'spVal': 0.166431,\n",
       "                'left': -14.740059,\n",
       "                'right': -6.512506},\n",
       "               'right': -27.405211},\n",
       "              'right': 0.225886}},\n",
       "            'right': {'spInd': 0,\n",
       "             'spVal': 0.13988,\n",
       "             'left': 7.557349,\n",
       "             'right': 7.336784}},\n",
       "           'right': -29.087463},\n",
       "          'right': 22.478291}}}}},\n",
       "     'right': -39.524461},\n",
       "    'right': {'spInd': 0,\n",
       "     'spVal': 0.124723,\n",
       "     'left': 22.891675,\n",
       "     'right': {'spInd': 0,\n",
       "      'spVal': 0.085111,\n",
       "      'left': {'spInd': 0,\n",
       "       'spVal': 0.108801,\n",
       "       'left': 6.196516,\n",
       "       'right': {'spInd': 0,\n",
       "        'spVal': 0.10796,\n",
       "        'left': -16.106164,\n",
       "        'right': {'spInd': 0,\n",
       "         'spVal': 0.085873,\n",
       "         'left': -1.293195,\n",
       "         'right': -10.137104}}},\n",
       "      'right': {'spInd': 0,\n",
       "       'spVal': 0.084661,\n",
       "       'left': 37.820659,\n",
       "       'right': {'spInd': 0,\n",
       "        'spVal': 0.080061,\n",
       "        'left': -24.132226,\n",
       "        'right': {'spInd': 0,\n",
       "         'spVal': 0.068373,\n",
       "         'left': 15.824970500000001,\n",
       "         'right': {'spInd': 0,\n",
       "          'spVal': 0.061219,\n",
       "          'left': -15.160836,\n",
       "          'right': {'spInd': 0,\n",
       "           'spVal': 0.044737,\n",
       "           'left': {'spInd': 0,\n",
       "            'spVal': 0.053764,\n",
       "            'left': {'spInd': 0,\n",
       "             'spVal': 0.055862,\n",
       "             'left': 6.695567,\n",
       "             'right': -3.131497},\n",
       "            'right': -13.731698},\n",
       "           'right': 4.091626}}}}}}}}}}}"
      ]
     },
     "execution_count": 85,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "myDat2 = loadDataSet(\"ex2.txt\")\n",
    "myMat2 = mat(myDat2)\n",
    "myTree = createTree(myMat2, ops=(0, 1))\n",
    "myDatTest = loadDataSet(\"ex2test.txt\")\n",
    "myMat2Test = mat(myDatTest)\n",
    "prune(myTree, myMat2Test)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 9.5.模型树"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 86,
   "metadata": {},
   "outputs": [],
   "source": [
    "def linearSolve(dataSet):\n",
    "    \"\"\"\n",
    "    模型树节点生成函数\n",
    "    参数：\n",
    "        dataSet -- 数据集\n",
    "    返回：\n",
    "        ws -- 回归参数\n",
    "        X -- x数据\n",
    "        Y -- y数据\n",
    "    \"\"\"\n",
    "    m,n = shape(dataSet)\n",
    "    X = mat(ones((m,n))); Y = mat(ones((m,1)))\n",
    "    X[:,1:n] = dataSet[:,0:n-1]; Y = dataSet[:,-1]\n",
    "    xTx = X.T*X\n",
    "    if linalg.det(xTx) == 0.0:\n",
    "        raise NameError('This matrix is singular, cannot do inverse,\\n\\\n",
    "        try increasing the second value of ops')\n",
    "    ws = xTx.I * (X.T * Y)\n",
    "    return ws,X,Y"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 87,
   "metadata": {},
   "outputs": [],
   "source": [
    "def modelLeaf(dataSet):\n",
    "    ws,X,Y = linearSolve(dataSet)\n",
    "    return ws"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 88,
   "metadata": {},
   "outputs": [],
   "source": [
    "def modelErr(dataSet):\n",
    "    ws,X,Y = linearSolve(dataSet)\n",
    "    yHat = X * ws\n",
    "    return sum(power(Y - yHat,2))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "测试"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 90,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'spInd': 0, 'spVal': 0.285477, 'left': matrix([[1.69855694e-03],\n",
       "         [1.19647739e+01]]), 'right': matrix([[3.46877936],\n",
       "         [1.18521743]])}"
      ]
     },
     "execution_count": 90,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "myMat2 = mat(loadDataSet(\"exp2.txt\"))\n",
    "createTree(myMat2, modelLeaf, modelErr, (1, 10))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 9.6.用树回归进行预测"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 92,
   "metadata": {},
   "outputs": [],
   "source": [
    "def regTreeEval(model, inDat):\n",
    "    return float(model)\n",
    "\n",
    "def modelTreeEval(model, inDat):\n",
    "    n = shape(inDat)[1]\n",
    "    X = mat(ones((1,n+1)))\n",
    "    X[:,1:n+1]=inDat\n",
    "    return float(X*model)\n",
    "\n",
    "def treeForeCast(tree, inData, modelEval=regTreeEval):\n",
    "    if not isTree(tree): return modelEval(tree, inData)\n",
    "    if inData[tree['spInd']] > tree['spVal']:\n",
    "        if isTree(tree['left']): return treeForeCast(tree['left'], inData, modelEval)\n",
    "        else: return modelEval(tree['left'], inData)\n",
    "    else:\n",
    "        if isTree(tree['right']): return treeForeCast(tree['right'], inData, modelEval)\n",
    "        else: return modelEval(tree['right'], inData)\n",
    "        \n",
    "def createForeCast(tree, testData, modelEval=regTreeEval):\n",
    "    m=len(testData)\n",
    "    yHat = mat(zeros((m,1)))\n",
    "    for i in range(m):\n",
    "        yHat[i,0] = treeForeCast(tree, mat(testData[i]), modelEval)\n",
    "    return yHat"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "测试"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 94,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.9640852318222141"
      ]
     },
     "execution_count": 94,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "trainMat = mat(loadDataSet(\"bikeSpeedVsIq_train.txt\"))\n",
    "testMat = mat(loadDataSet(\"bikeSpeedVsIq_test.txt\"))\n",
    "myTree = createTree(trainMat, ops=(1,20))\n",
    "yHat = createForeCast(myTree, testMat[:,0])\n",
    "corrcoef(yHat, testMat[:,1], rowvar=0)[0,1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 98,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.9760412191380593"
      ]
     },
     "execution_count": 98,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "myTree = createTree(trainMat, modelLeaf, modelErr, ops=(1,20))\n",
    "yHat = createForeCast(myTree, testMat[:,0], modelTreeEval)\n",
    "corrcoef(yHat, testMat[:,1], rowvar=0)[0,1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 99,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "matrix([[37.58916794],\n",
       "        [ 6.18978355]])"
      ]
     },
     "execution_count": 99,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "ws, X, Y = linearSolve(trainMat)\n",
    "ws"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 100,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.9434684235674766"
      ]
     },
     "execution_count": 100,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "for i in range(shape(testMat)[0]):\n",
    "    yHat[i] = testMat[i, 0] * ws[1, 0] + ws[0, 0]\n",
    "\n",
    "corrcoef(yHat, testMat[:,1], rowvar=0)[0, 1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
